{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Analysis Simple Agent with PydanticAI\n",
    "\n",
    "**This tutorial is based on the LangChain tutorial: `Data Analysis Simple Agent`. It demonstrates the same concept using PydanticAI as the agent framework.**\n",
    "\n",
    "**You don’t need to be familiar with the LangChain notebook to follow along—this tutorial stands on its own and explains everything you need to know.** For more information about PydanticAI, visit their [official website](https://ai.pydantic.dev/), or check out the PydanticAI Overview in [this notebook](https://github.com/NirDiamant/GenAI_Agents/blob/main/all_agents_tutorials/simple_conversational_agent-pydanticai.ipynb).\n",
    "\n",
    "In this version of the notebook, we replicate the [Data Analysis Simple Agent](https://github.com/NirDiamant/GenAI_Agents/blob/main/all_agents_tutorials/simple_data_analysis_agent_notebook.ipynb) workflow using **PydanticAI**. The primary difference is that LangChain includes a built-in agent designed to handle Pandas DataFrames and perform actions on them directly. PydanticAI, being a newer framework, does not yet include such a built-in tool. As a result, we’ll create the tool ourselves, providing an opportunity to explore how to build custom tools and implement retry logic with PydanticAI.\n",
    "\n",
    "## Overview\n",
    "This tutorial guides you through creating an AI-powered data analysis agent that can interpret and answer questions about a dataset using natural language. It combines language models with data manipulation tools to enable intuitive data exploration.\n",
    "\n",
    "## Motivation\n",
    "Data analysis often requires specialized knowledge, limiting access to insights for non-technical users. By creating an AI agent that understands natural language queries, we can democratize data analysis, allowing anyone to extract valuable information from complex datasets without needing to know programming or statistical tools.\n",
    "\n",
    "## Key Components\n",
    "1. Language Model: Processes natural language queries and generates human-like responses\n",
    "2. Data Manipulation Framework: Handles dataset operations and analysis\n",
    "3. Agent Framework: Connects the language model with data manipulation tools\n",
    "4. Synthetic Dataset: Represents real-world data for demonstration purposes\n",
    "\n",
    "## Method\n",
    "1. Create a synthetic dataset representing car sales data\n",
    "2. Construct an agent that combines the language model with data analysis capabilities\n",
    "3. Create a tool that the agent can use to query our dataset\n",
    "4. Implement a query processing function to handle natural language questions\n",
    "5. Demonstrate the agent's abilities with example queries\n",
    "\n",
    "## Conclusion\n",
    "This approach to data analysis offers significant benefits:\n",
    "- Accessibility for non-technical users\n",
    "- Flexibility in handling various query types\n",
    "- Efficient ad-hoc data exploration\n",
    "\n",
    "By making data insights more accessible, this method has the potential to transform how organizations leverage their data for decision-making across various fields and industries."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import libraries and set environment variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install -q pydantic-ai pandas numpy python-dotenv nest_asyncio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from dataclasses import dataclass\n",
    "from datetime import datetime, timedelta\n",
    "from dotenv import load_dotenv\n",
    "from typing import Any\n",
    "\n",
    "from pydantic_ai import Agent, RunContext, ModelRetry\n",
    "\n",
    "# Set a random seed for reproducibility\n",
    "np.random.seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply `nest_asyncio` to avoid errors when running asyncio code in a Jupyter notebook.\n",
    "# This prevents `event loop is already running` errors by allowing nested event loops.\n",
    "\n",
    "import nest_asyncio\n",
    "nest_asyncio.apply()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
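    "# Load environment variables from a local `.env` file, if present\n",
    "load_dotenv()\n",
    "# Fail early with a clear message instead of a `TypeError` when the key is missing\n",
    "if not os.getenv('OPENAI_API_KEY'):\n",
    "    raise RuntimeError('OPENAI_API_KEY is not set; add it to your environment or .env file')\n",
    "os.environ['LOGFIRE_IGNORE_NO_CONFIG'] = '1'"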
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate Sample Data\n",
    "\n",
    "In this section, we create a sample dataset of car sales. This includes generating dates, car makes, models, colors, and other relevant information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "First few rows of the generated data:\n",
      "        Date       Make      Model  Color  Year     Price  Mileage  \\\n",
      "0 2022-01-01   Mercedes      Sedan  Green  2022  57952.65   5522.0   \n",
      "1 2022-01-02  Chevrolet  Hatchback    Red  2021  58668.22  94238.0   \n",
      "2 2022-01-03       Audi      Truck  White  2019  69187.87   7482.0   \n",
      "3 2022-01-04     Nissan  Hatchback  Black  2016  40004.44  43846.0   \n",
      "4 2022-01-05   Mercedes  Hatchback    Red  2016  63983.07  52988.0   \n",
      "\n",
      "   EngineSize  FuelEfficiency SalesPerson  \n",
      "0         2.0            24.7       Alice  \n",
      "1         1.6            26.2         Bob  \n",
      "2         2.0            28.0       David  \n",
      "3         3.5            24.8       David  \n",
      "4         2.5            24.1       Alice  \n",
      "\n",
      "DataFrame info:\n",
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 1000 entries, 0 to 999\n",
      "Data columns (total 10 columns):\n",
      " #   Column          Non-Null Count  Dtype         \n",
      "---  ------          --------------  -----         \n",
      " 0   Date            1000 non-null   datetime64[ns]\n",
      " 1   Make            1000 non-null   object        \n",
      " 2   Model           1000 non-null   object        \n",
      " 3   Color           1000 non-null   object        \n",
      " 4   Year            1000 non-null   int64         \n",
      " 5   Price           1000 non-null   float64       \n",
      " 6   Mileage         1000 non-null   float64       \n",
      " 7   EngineSize      1000 non-null   float64       \n",
      " 8   FuelEfficiency  1000 non-null   float64       \n",
      " 9   SalesPerson     1000 non-null   object        \n",
      "dtypes: datetime64[ns](1), float64(4), int64(1), object(4)\n",
      "memory usage: 78.3+ KB\n",
      "\n",
      "Summary statistics:\n",
      "                      Date         Year         Price       Mileage  \\\n",
      "count                 1000  1000.000000   1000.000000   1000.000000   \n",
      "mean   2023-05-15 12:00:00  2018.445000  51145.360800  48484.643000   \n",
      "min    2022-01-01 00:00:00  2015.000000  20026.570000     19.000000   \n",
      "25%    2022-09-07 18:00:00  2017.000000  36859.940000  23191.500000   \n",
      "50%    2023-05-15 12:00:00  2018.000000  52215.155000  47506.000000   \n",
      "75%    2024-01-20 06:00:00  2020.000000  65741.147500  73880.250000   \n",
      "max    2024-09-26 00:00:00  2022.000000  79972.640000  99762.000000   \n",
      "std                    NaN     2.256117  17041.610861  29103.404593   \n",
      "\n",
      "        EngineSize  FuelEfficiency  \n",
      "count  1000.000000     1000.000000  \n",
      "mean      2.744500       29.688500  \n",
      "min       1.600000       20.000000  \n",
      "25%       2.000000       24.500000  \n",
      "50%       2.500000       29.700000  \n",
      "75%       3.500000       34.700000  \n",
      "max       4.000000       40.000000  \n",
      "std       0.839389        5.896316  \n"
     ]
    }
   ],
   "source": [
    "# Generate sample data\n",
    "n_rows = 1000\n",
    "\n",
    "# Generate dates\n",
    "start_date = datetime(2022, 1, 1)\n",
    "dates = [start_date + timedelta(days=i) for i in range(n_rows)]\n",
    "\n",
    "# Define data categories\n",
    "makes = ['Toyota', 'Honda', 'Ford', 'Chevrolet', 'Nissan', 'BMW', 'Mercedes', 'Audi', 'Hyundai', 'Kia']\n",
    "models = ['Sedan', 'SUV', 'Truck', 'Hatchback', 'Coupe', 'Van']\n",
    "colors = ['Red', 'Blue', 'Black', 'White', 'Silver', 'Gray', 'Green']\n",
    "\n",
    "# Create the dataset\n",
    "data = {\n",
    "    'Date': dates,\n",
    "    'Make': np.random.choice(makes, n_rows),\n",
    "    'Model': np.random.choice(models, n_rows),\n",
    "    'Color': np.random.choice(colors, n_rows),\n",
    "    'Year': np.random.randint(2015, 2023, n_rows),\n",
    "    'Price': np.random.uniform(20000, 80000, n_rows).round(2),\n",
    "    'Mileage': np.random.uniform(0, 100000, n_rows).round(0),\n",
    "    'EngineSize': np.random.choice([1.6, 2.0, 2.5, 3.0, 3.5, 4.0], n_rows),\n",
    "    'FuelEfficiency': np.random.uniform(20, 40, n_rows).round(1),\n",
    "    'SalesPerson': np.random.choice(['Alice', 'Bob', 'Charlie', 'David', 'Eva'], n_rows)\n",
    "}\n",
    "\n",
    "# Create DataFrame and sort by date\n",
    "df = pd.DataFrame(data).sort_values('Date')\n",
    "\n",
    "# Display sample data and statistics\n",
    "print(\"\\nFirst few rows of the generated data:\")\n",
    "print(df.head())\n",
    "\n",
    "print(\"\\nDataFrame info:\")\n",
    "df.info()\n",
    "\n",
    "print(\"\\nSummary statistics:\")\n",
    "print(df.describe())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Data Analysis Agent\n",
    "\n",
    "Unlike the [LangChain notebook](https://github.com/NirDiamant/GenAI_Agents/blob/main/all_agents_tutorials/simple_data_analysis_agent_notebook.ipynb) on which this example is based, PydanticAI does not (yet?) have a built-in tool for processing Pandas DataFrames.\n",
    "\n",
    "To address this, we’ll create a custom tool that implements the required functionality for our example.\n",
    "\n",
    "We’ll begin by defining the agent itself, along with the dependencies that the tool will need. The tool implementation will follow in the next section."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Dependencies\n",
    "\n",
    "PydanticAI uses a [dependency injection system](https://ai.pydantic.dev/dependencies/) to provide data and services to an agent’s system prompts, tools, and result validators.\n",
    "\n",
    "We’ll use this system to define the DataFrame as a dependency, allowing us to reference it inside the tool."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class Deps:\n",
    "    \"\"\"The only dependency we need is the DataFrame we'll be working with.\"\"\"\n",
    "    df: pd.DataFrame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent = Agent(\n",
    "    model='openai:gpt-4o-mini',\n",
    "    system_prompt=\"\"\"You are an AI assistant that helps extract information from a pandas DataFrame.\n",
    "    If asked about columns, be sure to check the column names first.\n",
    "    Be concise in your answers.\"\"\",\n",
    "    deps_type=Deps,\n",
    "\n",
    "    # Allow the agent to make mistakes and correct itself. Details will be covered in the tool definition.\n",
    "    retries=10,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a Tool to Query the DataFrame\n",
    "\n",
    "Our tool is straightforward. The LangChain function `create_pandas_dataframe_agent`, used in the [LangChain notebook](https://github.com/NirDiamant/GenAI_Agents/blob/main/all_agents_tutorials/simple_data_analysis_agent_notebook.ipynb), runs model-generated code in a Python REPL, which can be dangerous. We instead run our queries using `pd.eval`.\n",
    "\n",
    "`pd.eval` accepts only a subset of Python and Pandas expressions, which reduces, but does not eliminate, the potential for malicious code execution. Method calls on the DataFrame are still allowed, so treat the tool as safer than a REPL rather than as a sandbox.\n",
    "\n",
    "The downside is that the agent may occasionally write a query that uses unsupported syntax. To handle such cases, we enable retries: when we defined the agent, we set the number of allowed retries to 10. If an error occurs during tool execution, we raise a `ModelRetry` exception, which sends the error message back to the model and prompts it to retry with a corrected query."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "@agent.tool\n",
    "async def df_query(ctx: RunContext[Deps], query: str) -> str:\n",
    "    \"\"\"A tool for running queries on the `pandas.DataFrame`. Use this tool to interact with the DataFrame.\n",
    "\n",
    "    The DataFrame is available under the name `df`. `query` will be executed using\n",
    "    `pd.eval(query, local_dict={'df': df})`, so it must contain syntax compatible with `pandas.eval`.\n",
    "    \"\"\"\n",
    "\n",
    "    # Print the query for debugging purposes and fun :)\n",
    "    print(f'Running query: `{query}`')\n",
    "    try:\n",
    "        # Bind the injected DataFrame to the name `df` via `local_dict`. `target` only affects assignments,\n",
    "        # so without this binding the query would fall back to the notebook's global `df`.\n",
    "        # The result is returned as a string (must be serializable).\n",
    "        return str(pd.eval(query, local_dict={'df': ctx.deps.df}))\n",
    "    except Exception as e:\n",
    "        # On error, raise a `ModelRetry` exception with feedback for the agent.\n",
    "        raise ModelRetry(f'query: `{query}` is not a valid query. Reason: `{e}`') from e"
   ]
  },
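  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see in isolation the behavior the tool relies on, here is a minimal sketch that is not part of the original workflow. The helper `try_query` and the three-row `sample_df` are hypothetical names introduced only for illustration, and pinning `engine='python'` is an assumption made so the example does not depend on optional packages. It shows a valid expression succeeding, while a wrong column name and a non-expression statement each produce the kind of error text that would be sent back to the model through `ModelRetry`.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Tiny stand-in for the car sales data (values are made up for illustration).\n",
    "sample_df = pd.DataFrame({'Make': ['Audi', 'Kia', 'Audi'], 'Price': [30000.0, 20000.0, 10000.0]})\n",
    "\n",
    "\n",
    "def try_query(query: str, frame: pd.DataFrame) -> tuple[bool, str]:\n",
    "    '''Hypothetical helper: evaluate `query` with `pd.eval` and report success or the error text.'''\n",
    "    try:\n",
    "        # `local_dict` binds the name `df` inside the expression to `frame`.\n",
    "        result = pd.eval(query, local_dict={'df': frame}, engine='python')\n",
    "        return True, str(result)\n",
    "    except Exception as exc:  # in the real tool, this is where `ModelRetry` is raised\n",
    "        return False, f'{type(exc).__name__}: {exc}'\n",
    "\n",
    "\n",
    "# An expression that uses the correct column name succeeds.\n",
    "print(try_query(\"df['Price'].mean()\", sample_df))\n",
    "\n",
    "# A wrong column name fails; the error text is the feedback the agent would receive.\n",
    "print(try_query(\"df['price'].mean()\", sample_df))\n",
    "\n",
    "# Statements such as imports are not expressions, so `pd.eval` rejects them.\n",
    "print(try_query('import os', sample_df))\n",
    "```\n",
    "\n",
    "Catching the broad `Exception` is deliberate in this sketch: whatever goes wrong, the message becomes feedback for the next attempt rather than a crash."
   ]
  },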
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Question-Asking Function\n",
    "\n",
    "This function lets us ask our data analysis agent a question and prints both the question and the agent's answer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ask_agent(question):\n",
    "    \"\"\"Function to ask questions to the agent and display the response\"\"\"\n",
    "    deps = Deps(df=df)\n",
    "    print(f\"Question: {question}\")\n",
    "    response = agent.run_sync(question, deps=deps)\n",
    "    print(f\"Answer: {response.new_messages()[-1].content}\")\n",
    "    print(\"---\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example Questions\n",
    "\n",
    "Here are some example questions you can ask the data analysis agent. You can modify these or add your own questions to analyze the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Question: What are the column names in this dataset?\n",
      "Running query: `df.columns.tolist()`\n",
      "Answer: The column names in the dataset are: \n",
      "\n",
      "- Date\n",
      "- Make\n",
      "- Model\n",
      "- Color\n",
      "- Year\n",
      "- Price\n",
      "- Mileage\n",
      "- EngineSize\n",
      "- FuelEfficiency\n",
      "- SalesPerson\n",
      "---\n",
      "Question: How many rows are in this dataset?\n",
      "Running query: `len(df)`\n",
      "Running query: `df.shape[0]`\n",
      "Answer: The dataset contains 1,000 rows.\n",
      "---\n",
      "Question: What is the average price of cars sold?\n",
      "Running query: `cars['price'].mean()`\n",
      "Running query: `df['price'].mean()`\n",
      "Running query: `df.columns`\n",
      "Running query: `df['Price'].mean()`\n",
      "Answer: The average price of cars sold is approximately $51,145.36.\n",
      "---\n"
     ]
    }
   ],
   "source": [
    "# Example questions\n",
    "ask_agent(\"What are the column names in this dataset?\")\n",
    "ask_agent(\"How many rows are in this dataset?\")\n",
    "ask_agent(\"What is the average price of cars sold?\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Analysis of Examples\n",
    "\n",
    "As the output above shows, the agent retrieved the column names on its first attempt but needed more than one query for the other two questions. For the row count, it tried `len(df)` before switching to `df.shape[0]`.\n",
    "\n",
    "For the average price, the agent first guessed a DataFrame named `cars` and a lowercase `price` column. Only after inspecting `df.columns` did it use the correct name, `Price`, which starts with a capital `P`. We could improve the agent’s performance by including additional context in the prompt, such as the DataFrame's variable name, column names, types, or usage examples, to help the agent arrive at correct answers in fewer attempts."
   ]
  },
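  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to supply the extra context suggested above is to describe the DataFrame's schema in text. The sketch below is an illustration under stated assumptions, not part of the original notebook: `describe_columns` is a hypothetical helper, and the commented-out hook shows one possible way to register its output as an additional system prompt with PydanticAI's `@agent.system_prompt` decorator.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def describe_columns(frame: pd.DataFrame) -> str:\n",
    "    '''Hypothetical helper: list each column with its dtype, one per line.'''\n",
    "    lines = [f'- {name} ({dtype})' for name, dtype in frame.dtypes.items()]\n",
    "    header = 'The DataFrame is available as `df` and has these columns:'\n",
    "    return '\\n'.join([header, *lines])\n",
    "\n",
    "\n",
    "# Small made-up frame, used only to show what the description looks like.\n",
    "print(describe_columns(pd.DataFrame({'Price': [57952.65], 'Year': [2022]})))\n",
    "\n",
    "# One possible way to hook this into the agent (assumption, shown as comments only):\n",
    "# @agent.system_prompt\n",
    "# def add_schema(ctx: RunContext[Deps]) -> str:\n",
    "#     return describe_columns(ctx.deps.df)\n",
    "```\n",
    "\n",
    "Because the description is built from the DataFrame passed in through `Deps`, it would stay in sync with whichever dataset the agent is run against."
   ]
  },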
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
